import os
import cv2
import time
from imageio import save
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.ecbsr import ECBSR

def _super_resolve_frame(net, frame, scale, kernel, interpolation=cv2.INTER_CUBIC):
    """Upscale one BGR uint8 frame: luma through `net`, chroma by interpolation.

    Args:
        net: model called on a (1, 1, H, W) float32 tensor holding the Y channel
            in [0, 255]. Its output is assumed to be (1, 1, H*scale, W*scale) —
            TODO confirm against models/ecbsr.py.
        frame: BGR uint8 image, as returned by ``cv2.VideoCapture.read()``.
        scale: integer upscaling factor (must match the network).
        kernel: 3x3 float32 sharpening kernel for ``cv2.filter2D``.
        interpolation: OpenCV interpolation flag used to upscale the chroma.

    Returns:
        BGR uint8 image of shape (H*scale, W*scale, 3).
    """
    h, w = frame.shape[:2]
    # OpenCV's YCrCb ordering: channel 0 is luma (Y), channels 1-2 are Cr, Cb.
    ycrcb = cv2.cvtColor(frame, cv2.COLOR_BGR2YCrCb)
    y = torch.from_numpy(ycrcb[..., 0].astype(np.float32)).unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        pred = net(y).clamp(0, 255)
    # Round (rather than truncate) when going back to 8-bit.
    pred_y = pred.squeeze(0).squeeze(0).round().to(torch.uint8).cpu().numpy()

    # Only Y goes through the network; chroma is simply interpolated up.
    upscaled = cv2.resize(ycrcb, (w * scale, h * scale), interpolation=interpolation)
    upscaled[..., 0] = pred_y
    bgr = cv2.cvtColor(upscaled, cv2.COLOR_YCrCb2BGR)

    # A 50/50 blend with the sharpened copy halves the sharpening strength.
    sharpened = cv2.filter2D(bgr, -1, kernel)
    return cv2.addWeighted(bgr, 0.5, sharpened, 0.5, 0)


def process_single_video(
    model_path="E:\\gitcode\\ecbsr\\ecbsr-main\\experiments\\ecbsr-x2-m4c8-prelu-2025-0316-2115\\models\\model_x2_29.pt",
    input_path="vidyo1_640x360.mp4",
    output_path="vidyo1_x2_720p.mp4",
    scale=2,
    fps=None,
):
    """Super-resolve every frame of a video with ECBSR and write the result.

    Args:
        model_path: checkpoint (state dict) loaded into the ECBSR network.
        input_path: video to read with ``cv2.VideoCapture``.
        output_path: MP4 (``mp4v``) file to write.
        scale: upscaling factor; must match the checkpoint.
        fps: output frame rate. ``None`` uses the source's rate, falling back
            to 15 when the source does not report one.

    Raises:
        OSError: if the input video or the output writer cannot be opened.
    """
    kernel = np.array([[0, -1, 0],
                       [-1, 5, -1],
                       [0, -1, 0]], dtype=np.float32)

    # Positional args presumably (modules, channels, with_idt, activation, scale,
    # colors), matching the "m4c8-prelu" checkpoint — confirm in models/ecbsr.py.
    net = ECBSR(4, 8, 0, 'prelu', scale, 1)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
    net.eval()

    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        raise OSError(f"Could not open input video: {input_path}")

    if fps is None:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # also rejects NaN
            fps = 15

    writer = None
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break

            t_start = time.perf_counter()
            out_frame = _super_resolve_frame(net, frame, scale, kernel)
            print("time cost::::", time.perf_counter() - t_start)

            if writer is None:
                # Size the writer from the real output frame; VideoWriter
                # silently drops frames whose size does not match.
                out_h, out_w = out_frame.shape[:2]
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                writer = cv2.VideoWriter(output_path, fourcc, fps, (out_w, out_h))
                if not writer.isOpened():
                    raise OSError(f"Could not open output video: {output_path}")
            writer.write(out_frame)
    finally:
        cap.release()
        if writer is not None:
            writer.release()
    
if __name__ == "__main__":
    # Script entry point: run the video super-resolution pipeline.
    process_single_video()